
import os
import cv2
import time
import random
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as transforms

from models.ResNet_LSTM import Res_LSTM





# Select the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)


def read_labels(label_path):
    """Read a tab-separated label file into a dictionary.

    Every non-blank line is stripped of surrounding whitespace and split on
    tab characters; the first field becomes the key and the second field
    becomes the value. Blank lines (for example a trailing newline at the
    end of the file) are skipped.

    Args:
        label_path: Path of the label text file.

    Returns:
        dict mapping the first field of each line to its second field
        (both as strings).

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-blank line has no second tab-separated field.
    """
    labels = {}
    # The context manager guarantees the handle is closed, even when a
    # malformed line raises part-way through the file.
    with open(label_path, 'r') as label_file:
        for raw_line in label_file:
            line = raw_line.strip()
            if not line:
                continue  # tolerate empty lines instead of raising IndexError
            fields = line.split('\t')
            labels[fields[0]] = fields[1]
    return labels

# Hyperparameters

# ResNet backbone choice; the other options listed by the author are
# "resnet18" and "resnet34".
arch = "resnet50"

# Clip geometry
sample_duration = 16    # number of frames fed to the network
sample_size = 256       # input image size

# Classifier
num_classes = 101       # only 101 words are trained for now

# LSTM settings
lstm_num_layers = 1
lstm_hidden_size = 512

if __name__ == '__main__':
    # Demo script: pick a random video from the dataset, run the Res_LSTM
    # model over a sliding window of sampled frames and display the frames
    # with the FPS / prediction overlay. Press ESC to quit.
    dataset_videos_path = r"G:\嵌入式竞赛-海思赛道\CSL手语数据集\color-gloss\color"  # sign-language dataset root
    labels_path = r"D:\Works\Hisi_codes\SLRDataset_isolate\labels_101_en.txt"   # labels of the 101 classes
    weights_path = r"D:\Works\Hisi_codes\SLR-master\res50lstm_models\res50lstm_epoch026.pth"  # checkpoint path

    # Per-frame preprocessing: resize, convert to tensor, normalize.
    transform = transforms.Compose([transforms.Resize([sample_size, sample_size]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])

    # Load labels
    print("Load labels from : ", labels_path)
    labels = read_labels(labels_path)

    # Choose a video.
    # NOTE(review): the directory position is used as the ground-truth class
    # index, which presumes os.listdir returns the class folders in class
    # order -- confirm on the target platform.
    dir_list = os.listdir(dataset_videos_path)  # all class folder names
    if not dir_list:
        raise RuntimeError("No class folders found in: " + dataset_videos_path)
    # random.randint is inclusive on both ends, so the upper bound must be
    # the last valid index (the original could pick index == num_classes).
    dir_index = random.randint(0, min(num_classes, len(dir_list)) - 1)
    class_dir = os.path.join(dataset_videos_path, dir_list[dir_index])
    videos_list = os.listdir(class_dir)  # all video names of the chosen class
    if not videos_list:
        raise RuntimeError("No videos found in: " + class_dir)
    # Pick from the second half of the class; randint requires integer
    # bounds, so use floor division instead of multiplying by 0.5.
    video_index = random.randint(len(videos_list) // 2, len(videos_list) - 1)
    video_path = os.path.join(class_dir, videos_list[video_index])
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise RuntimeError("Cannot open video: " + video_path)

    frames_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)  # total frame count
    # Frame sampling stride; clamp to 1 so videos shorter than
    # sample_duration do not cause a modulo-by-zero below.
    step = max(1, int(frames_num / sample_duration))

    # Build the network
    print("Creating model...", end="")
    model = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=num_classes, arch=arch).to(device)
    print("\tCreate model finshed")

    # Load weights; map_location lets a checkpoint saved on GPU load on a
    # CPU-only machine as well.
    print("Load weights from : ", weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # inference mode for dropout / batch-norm layers

    window_title = dir_list[dir_index] + "-" + videos_list[video_index]
    fps = 0.0
    count = 0    # number of frames read so far
    images = []  # sliding window of preprocessed frames awaiting inference
    try:
        with torch.no_grad():  # no gradients are needed for inference
            while True:
                t1 = time.time()
                # Read the next frame
                ref, frame = capture.read()
                if not ref:
                    break
                count += 1
                if count % step == 0:  # keep one frame out of every `step`
                    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # BGR -> RGB
                    images.append(transform(Image.fromarray(np.uint8(image))))  # frame -> PIL image -> tensor

                if len(images) == sample_duration:  # enough frames for one clip
                    clip = torch.stack(images, dim=0)  # (T, C, H, W)
                    # Move channels first and add the batch axis: (1, C, T, H, W)
                    clip = clip.permute(1, 0, 2, 3).unsqueeze(0)
                    images = images[1:]  # drop the oldest frame for the next window
                    # Predict
                    output = model(clip.to(device))

                    # Check the prediction and display it
                    pre_index = int(torch.argmax(output))  # index of the max score
                    print("true={:03d}  pre={:03d}".format(dir_index, pre_index))
                    if pre_index == dir_index:  # correct prediction
                        label_key = "{:03d}".format(pre_index)
                        # Fall back to the key itself if the label file lacks it.
                        label_text = labels.get(label_key, label_key)
                        frame = cv2.putText(frame, "true={:03d} pre={:03d}  ".format(dir_index, pre_index) + label_text, (0, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                else:
                    elapsed = time.time() - t1
                    if elapsed > 0:  # coarse clocks can report a zero interval
                        fps = (fps + (1. / elapsed)) / 2
                    print("fps= %.2f" % (fps))
                    frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow(window_title, frame)
                c = cv2.waitKey(30) & 0xff
                if c == 27:  # ESC
                    break
    finally:
        # Always free the capture and close the window, whether the loop
        # ended by ESC, end of video or an exception.
        capture.release()
        cv2.destroyAllWindows()


